{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f9760a4-bfea-4def-a5be-1a00a4babba2",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "593b47bf",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from compel import Compel, ReturnedEmbeddingsType\n",
    "from diffusers import DiffusionPipeline\n",
    "import torch\n",
    "\n",
    "device='cuda'\n",
    "pipeline = DiffusionPipeline.from_pretrained(\"stabilityai/stable-diffusion-xl-base-1.0\", variant=\"fp16\", use_safetensors=True, torch_dtype=torch.float16).to(device)\n",
    "\n",
    "compel = Compel(tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2] , text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2], returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, requires_pooled=[False, True])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22f5e38b",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# upweight \"ball\"\n",
    "prompt = \"a cat playing with a ball++ in the forest\"\n",
    "negative_prompt = \"deformed, ugly\"\n",
    "conditioning, pooled = compel([prompt, negative_prompt])\n",
    "print(conditioning.shape, pooled.shape)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [],
   "metadata": {
    "collapsed": false
   },
   "id": "26d34f881a1f1fc8"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "821a38a5",
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# generate image\n",
    "image = pipeline(prompt_embeds=conditioning[0:1], pooled_prompt_embeds=pooled[0:1], \n",
    "                 negative_prompt_embeds=conditioning[1:2], negative_pooled_prompt_embeds=pooled[1:2],\n",
    "                 num_inference_steps=24, width=768, height=768).images[0]\n",
    "image"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Long prompts"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "2a0aa04d1c6e3cd6"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "\n",
    "compel = Compel(tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2] , \n",
    "                text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2], \n",
    "                returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, \n",
    "                requires_pooled=[False, True],\n",
    "               truncate_long_prompts=False)\n",
    "\n",
    "prompt = \"a cat playing with a ball++ in the forest\"\n",
    "negative_prompt = \"a long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long negative prompt\"\n",
    "conditioning, pooled = compel([prompt, negative_prompt])\n",
    "print(conditioning.shape, pooled.shape)\n",
    "\n",
    "image = pipeline(prompt_embeds=conditioning[0:1], pooled_prompt_embeds=pooled[0:1], \n",
    "                 negative_prompt_embeds=conditioning[1:2], negative_pooled_prompt_embeds=pooled[1:2],\n",
    "                 num_inference_steps=24, width=768, height=768).images[0]\n",
    "image"
   ],
   "metadata": {
    "collapsed": false,
    "is_executing": true
   },
   "id": "e9941e0e42e76b6"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Sequential cpu offload"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "3ed69a9a9e07eae8"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1685e7f8-1e41-46c1-92f8-9dcfec025db0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# \n",
    "\n",
    "compel = Compel(tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2] , \n",
    "                text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2], \n",
    "                returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, \n",
    "                requires_pooled=[False, True],\n",
    "                device=\"cuda\")\n",
    "\n",
    "pipeline.enable_sequential_cpu_offload()\n",
    "prompt = \"a cat playing with a ball++ in the forest\"\n",
    "negative_prompt = \"deformed, ugly\"\n",
    "conditioning, pooled = compel([prompt, negative_prompt])\n",
    "print(conditioning.shape, pooled.shape)\n",
    "\n",
    "image = pipeline(prompt_embeds=conditioning[0:1], pooled_prompt_embeds=pooled[0:1], \n",
    "                 negative_prompt_embeds=conditioning[1:2], negative_pooled_prompt_embeds=pooled[1:2],\n",
    "                 num_inference_steps=24, width=768, height=768).images[0]\n",
    "image"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Different prompts for different encoders"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "6a96381af5bdb52b"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "compel1 = Compel(\n",
    "    tokenizer=pipeline.tokenizer,\n",
    "    text_encoder=pipeline.text_encoder,\n",
    "    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,\n",
    "    requires_pooled=False,\n",
    ")\n",
    "\n",
    "compel2 = Compel(\n",
    "    tokenizer=pipeline.tokenizer_2,\n",
    "    text_encoder=pipeline.text_encoder_2,\n",
    "    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,\n",
    "    requires_pooled=True,\n",
    ")\n",
    "\n",
    "conditioning1 = compel1(prompt1)\n",
    "conditioning2, pooled = compel2(prompt2)\n",
    "conditioning = torch.cat((conditioning1, conditioning2), dim=-1)\n",
    "\n",
    "image = pipeline(prompt_embeds=conditioning, pooled_prompt_embeds=pooled, num_inference_steps=30).images[0]"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "ab10c0d2d52603c4"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
